import os
import json
import argparse

from model import DecoderBase, make_model
from data import get_bigcodebench, write_jsonl
import gc
import torch
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel


def codegen(
    model: DecoderBase,
    save_path: str,
    split: str,
    subset="full",
    greedy=False,
    strip_newlines=False,
    n_samples=1,
    id_range=None,
    resume=True,
):
    """Generate solutions for every BigCodeBench task and write them to ``save_path``.

    All prompts of the selected subset are collected, handed to
    ``model.codegen`` in a single call, and the results are written with
    ``write_jsonl`` as dicts of the form ``{"task_id": ..., "solution": ...}``.

    Args:
        model: Decoder used for generation. Must provide
            ``is_direct_completion()`` and
            ``codegen(prompts, do_sample=..., num_samples=...)``.
        save_path: Output file path; its parent directory is created if needed.
        split: Prompt flavour; the prompt is read from ``task[f"{split}_prompt"]``.
        subset: Dataset subset forwarded to ``get_bigcodebench``.
        greedy: When true, the model is called with ``do_sample=False``.
        strip_newlines: When true, leading/trailing ``"\\n"`` characters are
            stripped from each prompt.
        n_samples: Forwarded to the model as ``num_samples``.
        id_range: Accepted for interface compatibility; not used by this function.
        resume: Accepted for interface compatibility; not used by this function.

    Raises:
        Exception: If the model is a direct-completion model and ``split`` is
            ``"instruct"``.
        AssertionError: If the model returns no outputs.
    """
    # Query once: the answer is needed both for the sanity check and for
    # assembling every sample below.
    direct_completion = model.is_direct_completion()

    # Fail fast on an unsupported combination, before the dataset is loaded.
    if direct_completion and split == "instruct":
        raise Exception("Base model does not support direct completion for instruct tasks")

    dataset = get_bigcodebench(subset=subset)

    # Create the parent directory of save_path if needed, e.g. "a" for "a/b.jsonl".
    # exist_ok avoids the check-then-create race of a separate exists() test.
    dirname = os.path.dirname(save_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    prompt_key = f"{split}_prompt"
    task_ids, prompts, complete_prompts = [], [], []
    for task_id, task in dataset.items():
        prompt = task[prompt_key]
        if strip_newlines:
            prompt = prompt.strip("\n")
        task_ids.append(task_id)
        complete_prompts.append(task["complete_prompt"])
        prompts.append(prompt)

    outputs = model.codegen(prompts, do_sample=not greedy, num_samples=n_samples)
    assert outputs, "No outputs from model!"

    # NOTE(review): outputs are paired with tasks purely by position, and zip()
    # silently truncates on a length mismatch -- confirm that model.codegen
    # returns exactly one entry per prompt, in order, also when n_samples > 1.
    samples = []
    for task_id, complete_prompt, completion in zip(task_ids, complete_prompts, outputs):
        # Direct-completion models only produce the continuation, so the
        # complete prompt is prepended to form the full solution.
        solution = complete_prompt + completion if direct_completion else completion
        samples.append(dict(task_id=task_id, solution=solution))

    print(f"Generated {len(samples)} samples")
    write_jsonl(save_path, samples)


def main():
    """Command-line entry point: parse options, build the model, run generation.

    Steps:
      1. Parse the command-line arguments.
      2. Normalise greedy decoding: ``--greedy``, or ``temperature == 0`` with
         ``n_samples == 1``, forces ``bs=1``, ``n_samples=1``, ``temperature=0``.
      3. Validate ``--id_range`` (must be strictly increasing); an invalid
         range is reported through ``parser.error``, which exits the process.
      4. Build the model with ``make_model`` and call ``codegen``.
      5. For the ``vllm`` backend, tear down the engine afterwards.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--save_path", default=None, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--strip_newlines", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--chat_mode", action="store_true", default=False)
    parser.add_argument("--id_range", nargs=2, type=int)
    parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
    parser.add_argument("--base_url", default=None, type=str)
    parser.add_argument("--tp", default=1, type=int)
    parser.add_argument("--tokenizer_legacy", action="store_true")
    parser.add_argument("--tokenizer_name", default=None, type=str)

    args = parser.parse_args()

    if args.greedy or (args.temperature == 0 and args.n_samples == 1):
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        args.greedy = True
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # nargs=2 already guarantees exactly two integers, so only the ordering
        # needs checking. parser.error is used instead of assert because
        # asserts are stripped under ``python -O``.
        low, high = args.id_range
        if low >= high:
            parser.error("id_range must be increasing")
        args.id_range = (low, high)

    # Instantiate the model runner for the selected backend.
    model_runner = make_model(
        model=args.model,
        backend=args.backend,
        batch_size=args.bs,
        temperature=args.temperature,
        base_url=args.base_url,
        tp=args.tp,
        tokenizer_name=args.tokenizer_name,
        tokenizer_legacy=args.tokenizer_legacy,
        chat_mode=args.chat_mode,
    )

    # Default output name encodes model, subset, split, backend, temperature
    # and sample count; "/" in the model name is replaced so the name stays a
    # single path component.
    extra = "-" + args.subset if args.subset != "full" else ""
    if args.save_path:
        save_path = args.save_path
    else:
        model_tag = args.model.replace("/", "--")
        save_path = f"{model_tag}--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"

    codegen(
        model=model_runner,
        save_path=save_path,
        split=args.split,
        subset=args.subset,
        greedy=args.greedy,
        strip_newlines=args.strip_newlines,
        n_samples=args.n_samples,
        resume=args.resume,
        id_range=args.id_range,
    )

    if args.backend == "vllm":
        # Tear down vLLM's parallel/distributed state, drop the references to
        # the engine, then run the garbage collector and empty torch's CUDA cache.
        print("Try cleanup...")
        destroy_model_parallel()
        destroy_distributed_environment()
        del model_runner.llm.llm_engine.model_executor
        del model_runner.llm
        gc.collect()
        torch.cuda.empty_cache()

    print("===============================")
    print("----- END OF BigCodeBench -----")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
